Using slim to Build Model Architectures

This notebook gives a simple example of using TensorFlow's slim library to construct a "deeper" learning architecture for the MNIST dataset. We'll see how easy it is to construct architectures with slim, and how to write simple summaries that can be monitored in TensorBoard.

As a simple example, we'll use the architecture presented in Michael Nielsen's online textbook. This knowledge will later become useful as we explore more complicated architectures like Inception.


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data as mnist_data

In [2]:
# Read in MNIST dataset, compute mean / standard deviation of the training images
mnist = mnist_data.read_data_sets('MNIST_data', one_hot=True)

MEAN = np.mean(mnist.train.images)
STD = np.std(mnist.train.images)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [3]:
# Convenience method for reshaping images. The included MNIST dataset stores images
# as Nx784 row vectors. This method reshapes the inputs into Nx28x28x1 images that are
# better suited for convolution operations and rescales the inputs so they have a
# mean of 0 and unit variance.
def resize_images(images):
    reshaped = (images - MEAN)/STD
    reshaped = np.reshape(reshaped, [-1, 28, 28, 1])
    
    assert(reshaped.shape[1] == 28)
    assert(reshaped.shape[2] == 28)
    assert(reshaped.shape[3] == 1)
    
    return reshaped
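
For example, a quick shape check confirms the transformation (just a sanity-check sketch, assuming mnist has been loaded as above):

batch, _ = mnist.train.next_batch(4)
print(batch.shape)                  # (4, 784): row vectors
print(resize_images(batch).shape)   # (4, 28, 28, 1): images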

NielsenNet

The neural net architecture presented by Michael Nielsen in chapter 6 of his textbook achieves an accuracy in excess of 99%. I've dubbed this architecture NielsenNet. It consists of two convolution layers, followed by two fully connected layers, followed by an output layer. Dropout is used after each fully connected layer to control overfitting.

When building this model using slim, we'll use the built-in conv2d and max_pool2d functions to build the convolution layers, turning the 28x28 input images into 5x5x40 outputs for the fully connected layers. We'll do this with a combination of different padding modes (SAME vs. VALID) and max-pooling.
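
The output sizes follow directly from TensorFlow's padding rules: SAME pads the input so the output size is ceil(n / stride), while VALID uses only fully valid windows, giving (n - k) // stride + 1 for kernel size k. Here's a minimal sketch of the arithmetic (conv_output_size is just an illustrative helper, not a TensorFlow function):

def conv_output_size(n, k, stride, padding):
    # TensorFlow's sizing rules: SAME pads to ceil(n / stride); VALID
    # slides the kernel only over fully valid positions.
    if padding == 'SAME':
        return (n + stride - 1) // stride
    return (n - k) // stride + 1

print(conv_output_size(28, 5, 1, 'SAME'))   # 28: layer1-conv keeps 28x28
print(conv_output_size(28, 2, 2, 'VALID'))  # 14: layer2-max-pool halves it
print(conv_output_size(14, 5, 1, 'VALID'))  # 10: layer3-conv trims the border
print(conv_output_size(10, 2, 2, 'VALID'))  # 5:  layer4-max-pool halves again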

Finally we'll build a succession of fully-connected layers using the fully_connected convenience method. Dropout can be implemented using slim's dropout method, gated by the is_training tensor.

Building the network by successively mutating the net variable is pretty common, and something we'll see later on when we explore the Inception architecture.
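
One slim idiom worth knowing, though nielsen_net below doesn't need it, is slim.arg_scope, which sets shared defaults for repeated layer calls so each mutation of net stays short. A sketch, assuming an inputs tensor like the one passed to nielsen_net:

with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    # Every layer in this block inherits the activation and regularizer.
    net = slim.conv2d(inputs, 20, [5, 5], scope='conv1')
    net = slim.fully_connected(net, 1000, scope='fc1')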


In [4]:
def nielsen_net(inputs, is_training, scope='NielsenNet'):
    with tf.variable_scope(scope, 'NielsenNet'):
        # First Group: Convolution + Pooling 28x28x1 => 28x28x20 => 14x14x20
        net = slim.conv2d(inputs, 20, [5, 5], padding='SAME', scope='layer1-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer2-max-pool')

        # Second Group: Convolution + Pooling 14x14x20 => 10x10x40 => 5x5x40
        net = slim.conv2d(net, 40, [5, 5], padding='VALID', scope='layer3-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer4-max-pool')

        # Reshape: 5x5x40 => 1000x1
        net = tf.reshape(net, [-1, 5*5*40])

        # Fully Connected Layer: 1000x1 => 1000x1
        net = slim.fully_connected(net, 1000, scope='layer5')
        net = slim.dropout(net, is_training=is_training, scope='layer5-dropout')

        # Second Fully Connected: 1000x1 => 1000x1
        net = slim.fully_connected(net, 1000, scope='layer6')
        net = slim.dropout(net, is_training=is_training, scope='layer6-dropout')

        # Output Layer: 1000x1 => 10x1
        net = slim.fully_connected(net, 10, scope='output')
        net = slim.dropout(net, is_training=is_training, scope='output-dropout')

        return net

In [5]:
sess = tf.InteractiveSession()

# Create the placeholder tensors for the input images (x), the training labels (y_actual)
# and whether or not dropout is active (is_training)
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='Inputs')
y_actual = tf.placeholder(tf.float32, shape=[None, 10], name='Labels')
is_training = tf.placeholder(tf.bool, name='IsTraining')

# Pass the inputs into nielsen_net, outputting the logits
logits = nielsen_net(x, is_training, scope='NielsenNetTrain')
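
With the graph constructed, one optional sanity check (a sketch using slim's bookkeeping, not part of the original run) is to list the model variables and confirm their shapes match the layer comments in nielsen_net:

for var in slim.get_model_variables():
    print(var.op.name, var.get_shape().as_list())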

In [6]:
# Use the logits to create four additional operations:
#
# 1: The cross entropy of the predictions vs. the actual labels
# 2: The number of correct predictions
# 3: The accuracy given the number of correct predictions
# 4: The update step, using the MomentumOptimizer
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_actual, logits=logits))
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_actual, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.MomentumOptimizer(0.01, 0.5).minimize(cross_entropy)
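
For intuition, softmax_cross_entropy_with_logits is a numerically stable fusion of the softmax and the cross-entropy sum. A hand-rolled equivalent (shown only as a sketch; don't use it for training) would look like:

probs = tf.nn.softmax(logits)
manual_xent = tf.reduce_mean(-tf.reduce_sum(y_actual * tf.log(probs), axis=1))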

In [7]:
# To monitor our progress using tensorboard, create two summary operations
# to track the loss and the accuracy
loss_summary = tf.summary.scalar('loss', cross_entropy)
accuracy_summary = tf.summary.scalar('accuracy', accuracy)

sess.run(tf.global_variables_initializer())
train_writer = tf.summary.FileWriter('/tmp/nielsen-net', sess.graph)
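
Once training starts writing events, the summaries can be viewed by pointing TensorBoard at the log directory, e.g. tensorboard --logdir=/tmp/nielsen-net.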

In [8]:
eval_data = {
    x: resize_images(mnist.validation.images),
    y_actual: mnist.validation.labels,
    is_training: False
}

for i in range(100000):
    images, labels = mnist.train.next_batch(100)
    summary, _ = sess.run([loss_summary, train_step], feed_dict={x: resize_images(images), y_actual: labels, is_training: True})
    train_writer.add_summary(summary, i)
    
    if i % 1000 == 0:
        summary, acc = sess.run([accuracy_summary, accuracy], feed_dict=eval_data)
        train_writer.add_summary(summary, i)
        print("Step: %5d, Validation Accuracy = %5.2f%%" % (i, acc * 100))


Step:     0, Validation Accuracy = 10.60%
Step:  1000, Validation Accuracy = 96.74%
Step:  2000, Validation Accuracy = 97.78%
Step:  3000, Validation Accuracy = 98.26%
Step:  4000, Validation Accuracy = 98.42%
Step:  5000, Validation Accuracy = 98.66%
Step:  6000, Validation Accuracy = 98.84%
Step:  7000, Validation Accuracy = 98.88%
Step:  8000, Validation Accuracy = 98.88%
Step:  9000, Validation Accuracy = 98.96%
Step: 10000, Validation Accuracy = 98.94%
Step: 11000, Validation Accuracy = 99.06%
Step: 12000, Validation Accuracy = 99.12%
Step: 13000, Validation Accuracy = 99.10%
Step: 14000, Validation Accuracy = 99.16%
Step: 15000, Validation Accuracy = 99.04%
Step: 16000, Validation Accuracy = 99.20%
Step: 17000, Validation Accuracy = 99.16%
Step: 18000, Validation Accuracy = 99.18%
Step: 19000, Validation Accuracy = 99.16%
Step: 20000, Validation Accuracy = 99.22%
Step: 21000, Validation Accuracy = 99.26%
Step: 22000, Validation Accuracy = 99.24%
Step: 23000, Validation Accuracy = 99.26%
Step: 24000, Validation Accuracy = 99.32%
Step: 25000, Validation Accuracy = 99.34%
Step: 26000, Validation Accuracy = 99.32%
Step: 27000, Validation Accuracy = 99.32%
Step: 28000, Validation Accuracy = 99.34%
Step: 29000, Validation Accuracy = 99.36%
Step: 30000, Validation Accuracy = 99.30%
Step: 31000, Validation Accuracy = 99.36%
Step: 32000, Validation Accuracy = 99.28%
Step: 33000, Validation Accuracy = 99.38%
Step: 34000, Validation Accuracy = 99.38%
Step: 35000, Validation Accuracy = 99.38%
Step: 36000, Validation Accuracy = 99.40%
Step: 37000, Validation Accuracy = 99.36%
Step: 38000, Validation Accuracy = 99.34%
Step: 39000, Validation Accuracy = 99.42%
Step: 40000, Validation Accuracy = 99.38%
Step: 41000, Validation Accuracy = 99.40%
Step: 42000, Validation Accuracy = 99.40%
Step: 43000, Validation Accuracy = 99.42%
Step: 44000, Validation Accuracy = 99.38%
Step: 45000, Validation Accuracy = 99.42%
Step: 46000, Validation Accuracy = 99.40%
Step: 47000, Validation Accuracy = 99.30%
Step: 48000, Validation Accuracy = 99.38%
Step: 49000, Validation Accuracy = 99.30%
Step: 50000, Validation Accuracy = 99.44%
Step: 51000, Validation Accuracy = 99.38%
Step: 52000, Validation Accuracy = 99.42%
Step: 53000, Validation Accuracy = 99.46%
Step: 54000, Validation Accuracy = 99.44%
Step: 55000, Validation Accuracy = 99.42%
Step: 56000, Validation Accuracy = 99.38%
Step: 57000, Validation Accuracy = 99.40%
Step: 58000, Validation Accuracy = 99.44%
Step: 59000, Validation Accuracy = 99.38%
Step: 60000, Validation Accuracy = 99.44%
Step: 61000, Validation Accuracy = 99.38%
Step: 62000, Validation Accuracy = 99.42%
Step: 63000, Validation Accuracy = 99.42%
Step: 64000, Validation Accuracy = 99.42%
Step: 65000, Validation Accuracy = 99.38%
Step: 66000, Validation Accuracy = 99.44%
Step: 67000, Validation Accuracy = 99.38%
Step: 68000, Validation Accuracy = 99.30%
Step: 69000, Validation Accuracy = 99.42%
Step: 70000, Validation Accuracy = 99.40%
Step: 71000, Validation Accuracy = 99.36%
Step: 72000, Validation Accuracy = 99.40%
Step: 73000, Validation Accuracy = 99.38%
Step: 74000, Validation Accuracy = 99.40%
Step: 75000, Validation Accuracy = 99.36%
Step: 76000, Validation Accuracy = 99.40%
Step: 77000, Validation Accuracy = 99.40%
Step: 78000, Validation Accuracy = 99.38%
Step: 79000, Validation Accuracy = 99.38%
Step: 80000, Validation Accuracy = 99.44%
Step: 81000, Validation Accuracy = 99.40%
Step: 82000, Validation Accuracy = 99.44%
Step: 83000, Validation Accuracy = 99.48%
Step: 84000, Validation Accuracy = 99.44%
Step: 85000, Validation Accuracy = 99.46%
Step: 86000, Validation Accuracy = 99.48%
Step: 87000, Validation Accuracy = 99.44%
Step: 88000, Validation Accuracy = 99.42%
Step: 89000, Validation Accuracy = 99.48%
Step: 90000, Validation Accuracy = 99.42%
Step: 91000, Validation Accuracy = 99.42%
Step: 92000, Validation Accuracy = 99.48%
Step: 93000, Validation Accuracy = 99.44%
Step: 94000, Validation Accuracy = 99.44%
Step: 95000, Validation Accuracy = 99.46%
Step: 96000, Validation Accuracy = 99.48%
Step: 97000, Validation Accuracy = 99.44%
Step: 98000, Validation Accuracy = 99.48%
Step: 99000, Validation Accuracy = 99.38%

In [9]:
test_data = {
    x: resize_images(mnist.test.images),
    y_actual: mnist.test.labels,
    is_training: False
}

acc = sess.run(accuracy, feed_dict=test_data)

print("Test Accuracy = %5.2f%%" % (100 * acc))


Test Accuracy = 99.51%
